import os
from collections import defaultdict
from Bio import SeqIO
from pylab import *
from matplotlib import ticker


def kilo(x, pos):
    """Format a tick value, abbreviating thousands with a 'k' suffix.

    Values of 1000 and above are shown as whole thousands ('1k', '2k', ...);
    smaller values are shown as plain integers.

    Parameters:
        x: the tick value (int or float).
        pos: the tick position; unused, but part of the two-argument
            signature under which this function is handed to
            ticker.FuncFormatter.

    Returns:
        The label string for the tick.
    """
    # Use >= so that the tick at exactly 1000 is labelled '1k', consistent
    # with '2k', '3k', ... instead of the odd-one-out '1000'.
    if x >= 1000:
        return '%dk' % (x//1000)
    return '%d' % x

# Tick formatter wrapping kilo(); it is installed on the y axis of every
# panel in the plotting section at the end of this script.
formatter = ticker.FuncFormatter(kilo)


def read_scrna_annotations():
    """Read scRNA.fa and classify each record by biogenesis pathway.

    Each FASTA description is expected to have the form
    "<id> <free text> (<gene symbol>), <category>".  The category, free
    text and gene symbol are checked against the small set of entries
    this script knows about; anything else fails an assertion or raises.

    Returns:
        dict mapping record id -> (label, pathway), where pathway is
        "Pol-III" or "intronic".

    Raises:
        Exception: for a "small cytoplasmic RNA" record whose description
            is not one of the recognized ones.
        AssertionError: if a record does not have the expected layout.
    """
    fasta_path = os.path.join("/osc-fs_home/mdehoon/Data/CASPARs/Filters",
                              "scRNA.fa")
    pathways = {}
    for entry in SeqIO.parse(fasta_path, "fasta"):
        # Peel the trailing ", <category>" off, then the leading id.
        header, category = entry.description.rsplit(", ", 1)
        name, remainder = header.split(None, 1)
        assert name == entry.id
        # The gene symbol is the final parenthesized token.
        assert remainder.endswith(")")
        paren = remainder.rindex("(")
        gene = remainder[paren+1:-1]
        text = remainder[:paren].rstrip()
        if category == "long non-coding RNA":
            assert text == "Homo sapiens brain cytoplasmic RNA 1"
            assert gene == "BCYRN1"
            annotation = ("brain cytoplasmic RNA 1", "Pol-III")
        elif category == "non-coding RNA":
            assert text == "Homo sapiens RNA, 7SL, cytoplasmic 832, pseudogene"
            assert gene == "RN7SL832P"
            annotation = ("7SL", "Pol-III")
        else:
            assert category == "small cytoplasmic RNA"
            if text.startswith("Homo sapiens RNA component of signal recognition particle 7SL"):
                assert gene in ("RN7SL1", "RN7SL2", "RN7SL3")
                annotation = ("7SL", "Pol-III")
            elif text == "Homo sapiens MALAT1-associated small cytoplasmic RNA":
                assert gene == "MASCRNA"
                annotation = (gene, "intronic")
            else:
                raise Exception("Unknown gene '%s' with description '%s'" % (gene, text))
        pathways[entry.id] = annotation
    return pathways

# Rfam family name (third '|'-separated field of an ENST record) -> snRNA.
_SNRNA_BY_RFAM = {
    "U1 spliceosomal RNA": "U1",
    "U2 spliceosomal RNA": "U2",
    "U4 spliceosomal RNA": "U4",
    "U4atac minor spliceosomal RNA": "U4atac",
    "U5 spliceosomal RNA": "U5",
    "U6 spliceosomal RNA": "U6",
    "U6atac minor spliceosomal RNA": "U6atac",
    "U7 small nuclear RNA": "U7",
    "U11 spliceosomal RNA": "U11",
    "U12 minor spliceosomal RNA": "U12",
}

# Exact description (second field of an ENST record) -> snRNA.
_SNRNA_BY_DESCRIPTION = {
    "Homo sapiens U6 spliceosomal RNA": "U6",
    "Homo sapiens (human) U6 spliceosomal RNA": "U6",
    "Homo sapiens (human) U6 spliceosomal RNA (multiple genes)": "U6",
}

# Name prefix (first field of an ENST record) -> snRNA; only consulted when
# both the description and the Rfam field are empty.
_SNRNA_BY_NAME_PREFIX = (
    ("RNU4ATAC", "U4atac"),
    ("RNU6ATAC", "U6atac"),
    ("RNA, U5A small nuclear ", "U5"),
    ("RNA, U6 small nuclear ", "U6"),
)

# Gene part of a "<gene>-<number>" name -> snRNA.
_SNRNA_BY_GENE = {
    "RNU1": "U1",
    "RNU2": "U2",
    "RNU4": "U4",
    "RNU5A": "U5",
    "RNU5B": "U5",
    "RNU5D": "U5",
    "RNU5F": "U5",
    "RNU6": "U6",
    "RNU7": "U7",
}

# Hand-assigned classes for ENST records that no rule above can classify.
_SNRNA_BY_TRANSCRIPT = {
    'ENST00000620626.1': "U4",
    'ENST00000636749.1': "U6",
    'ENST00000636931.1': "U6",
    'ENST00000636425.1': "U6",
    'ENST00000620349.1': "U4",
    'ENST00000637085.1': "U6",
    'ENST00000516584.1': "U6",
    'ENST00000516940.1': "U7",
    'ENST00000637295.1': "U2",
    'ENST00000636829.1': "U6",
    'ENST00000619968.1': "U7",
    'ENST00000618345.1': "U4atac",
    'ENST00000516898.1': "U11",
    'ENST00000619194.1': "U6",
    'ENST00000647487.1': "U7",
    'ENST00000646220.1': "U6",
}

# (description prefix, required gene-symbol prefix, snRNA) for non-ENST
# records in the "small nuclear RNA" category.  The prefixes are mutually
# exclusive, so the order of the rules does not matter.
_REFSEQ_SNRNA_FAMILIES = (
    ('Homo sapiens RNA, U1 small nuclear ', 'RNU1-', "U1"),
    ('Homo sapiens RNA, variant U1 small nuclear ', 'RNVU1-', "U1"),
    ('Homo sapiens RNA, U2 small nuclear ', 'RNU2-', "U2"),
    ('Homo sapiens RNA, U4 small nuclear ', 'RNU4-', "U4"),
    ('Homo sapiens RNA, U5A small nuclear ', 'RNU5A-', "U5"),
    ('Homo sapiens RNA, U5B small nuclear ', 'RNU5B-', "U5"),
    ('Homo sapiens RNA, U5D small nuclear ', 'RNU5D-', "U5"),
    ('Homo sapiens RNA, U5E small nuclear ', 'RNU5E-', "U5"),
    ('Homo sapiens RNA, U5F small nuclear ', 'RNU5F-', "U5"),
    ('Homo sapiens RNA, U6 small nuclear ', 'RNU6-', "U6"),
    ('Homo sapiens RNA, U7 small nuclear ', 'RNU7-', "U7"),
)


def _classify_ensembl_snrna(record):
    """Classify an ENST record; return (snRNA or None, description field)."""
    assert record.description.startswith(record.id)
    fields = record.description[len(record.id)+1:]
    # Exactly three '|'-separated fields are expected; anything else raises
    # ValueError from the unpacking.
    name, description, rfam = fields.split("|")
    snRNA = _SNRNA_BY_RFAM.get(rfam)
    if snRNA is None:
        snRNA = _SNRNA_BY_DESCRIPTION.get(description)
    if snRNA is None and description.startswith("Homo sapiens (human) U1 spliceosomal RNA"):
        snRNA = "U1"
    if snRNA is None and not description and not rfam:
        # Only the name is available: try known prefixes, then "<gene>-<n>".
        for prefix, family in _SNRNA_BY_NAME_PREFIX:
            if name.startswith(prefix):
                snRNA = family
                break
        else:
            parts = name.split("-")
            if len(parts) == 2:
                snRNA = _SNRNA_BY_GENE.get(parts[0])
    if snRNA is None:
        snRNA = _SNRNA_BY_TRANSCRIPT.get(record.id)
    return snRNA, description


def _classify_refseq_snrna(record):
    """Classify a non-ENST record; return (snRNA, biogenesis or None, description).

    The description is expected to look like
    "<id> <free text> (<gene symbol>), <category>".
    """
    header, category = record.description.rsplit(", ", 1)
    name, description = header.split(None, 1)
    assert name == record.id
    assert description.endswith(")")
    index = description.rindex("(")
    gene = description[index+1:-1]
    description = description[:index].rstrip()
    if category == "long non-coding RNA":
        assert description == "Homo sapiens brain cytoplasmic RNA 1"
        assert gene == "BCYRN1"
        return gene, "Pol-III", description
    if category == "non-coding RNA":
        assert description == "Homo sapiens RNA, 7SL, cytoplasmic 832, pseudogene"
        assert gene == "RN7SL832P"
        return gene, "Pol-III", description
    assert category == "small nuclear RNA"
    if description == 'Homo sapiens RNA component of 7SK nuclear ribonucleoprotein':
        assert gene == 'RN7SK'
        return "7SK", "Pol-III", description
    for prefix, gene_prefix, snRNA in _REFSEQ_SNRNA_FAMILIES:
        if description.startswith(prefix):
            assert gene.startswith(gene_prefix)
            return snRNA, None, description
    if description.startswith('Homo sapiens RNA, U4atac small nuclear '):
        assert gene == 'RNU4ATAC'
        return "U4atac", None, description
    if description.startswith('Homo sapiens RNA, U6atac small nuclear '):
        assert gene == 'RNU6ATAC'
        return "U6atac", None, description
    if description == 'Homo sapiens RNA, U11 small nuclear':
        assert gene == 'RNU11'
        return "U11", None, description
    if description == 'Homo sapiens RNA, U12 small nuclear':
        assert gene == 'RNU12'
        return "U12", None, description
    raise Exception("Unknown gene '%s' with description '%s'" % (gene, description))


def read_snrna_annotations():
    """Read snRNA.fa and classify each record by snRNA class and biogenesis.

    Records whose id starts with "ENST" carry '|'-separated name,
    description and Rfam fields; all other records carry a
    "<text> (<gene>), <category>" description.  Each format has its own
    classifier; the pathway is then derived from the snRNA class:
    U6 and U6atac are labelled "Pol-III", the other U-class snRNAs
    "Pol-II", and the remaining entries keep the pathway their classifier
    assigned.

    Returns:
        dict mapping record id -> (label, pathway).

    Raises:
        Exception: if a record cannot be classified.
        AssertionError: if a record does not have the expected layout.
    """
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
    filename = "snRNA.fa"
    path = os.path.join(directory, filename)
    records = SeqIO.parse(path, "fasta")
    snrna_annotations = {}
    for record in records:
        if record.id.startswith("ENST"):
            snRNA, description = _classify_ensembl_snrna(record)
            # Reset per record so a value from a previous record can never
            # leak into this one.
            biogenesis = None
        else:
            snRNA, biogenesis, description = _classify_refseq_snrna(record)
        if snRNA in ("U1", "U2", "U4", "U4atac", "U5", "U7", "U11", "U12"):
            snRNA = "%s spliceosomal RNA" % snRNA
            biogenesis = "Pol-II"
        elif snRNA in ("U6", "U6atac"):
            snRNA = "%s spliceosomal RNA" % snRNA
            biogenesis = "Pol-III"
        elif biogenesis is None:
            raise Exception("Unknown snRNA %s with description '%s'" % (snRNA, description))
        snrna_annotations[record.id] = (snRNA, biogenesis)
    return snrna_annotations

# Substrings of an NR_ description that identify U3, U8 and U13 records.
# They are tested in this order: U3, then U8, then U13.
_U3_MARKERS = (
    "C/D box 3A", "C/D box 3B", "C/D box 3C", "C/D box 3D",
    "C/D box 3E", "C/D box 3F", "C/D box 3G", "C/D box 3H",
    "C/D box 3I", "C/D box 3J", "C/D box 3K", "C/D box 3 pseudogene",
)
_U8_MARKERS = ("C/D box 118",)
_U13_MARKERS = (
    "C/D box 13 ", "C/D box 13A", "C/D box 13B", "C/D box 13C",
    "C/D box 13D", "C/D box 13E", "C/D box 13F", "C/D box 13G",
    "C/D box 13H", "C/D box 13I", "C/D box 13J",
)

# ENST descriptions matched exactly.
_ENSEMBL_SNORNA_EXACT = {
    "H/ACA box Small nucleolar RNA U109": "H/ACA box",
    "C/D box Small nucleolar RNA Z40": "C/D box",
    "C/D box Small nucleolar RNA U3": "U3",
    "C/D box U8": "U8",
    "C/D box Small nucleolar RNA U13": "U13",
}

# ENST descriptions matched by prefix.  None of the exact strings above
# starts with any of these prefixes, so the lookup order is irrelevant.
_ENSEMBL_HACA_PREFIXES = (
    "H/ACA box Small nucleolar RNA SNORA",
    "H/ACA box Small nucleolar RNA ACA",
    "H/ACA box SNORA",
    "H/ACA box ACA",
)
_ENSEMBL_CD_PREFIXES = (
    "C/D box Small nucleolar RNA SNORD",
    "C/D box SNORD",
    "C/D box Small nucleolar RNA U2-",
    "C/D box sno",
    "C/D box Small nucleolar RNA U83B",
    "C/D box Small nucleolar RNA MBII",
)


def _classify_refseq_snorna(record):
    """Classify an NR_ record as "U3", "U8", "U13", "C/D box" or "H/ACA box"."""
    description = record.description
    for markers, snoRNA in ((_U3_MARKERS, "U3"),
                            (_U8_MARKERS, "U8"),
                            (_U13_MARKERS, "U13")):
        if any(marker in description for marker in markers):
            return snoRNA
    # Generic record: "<name> <species text>, <class> <n> (<symbol>),
    # small nucleolar RNA"; the class is read from the middle term.
    assert record.description.startswith(record.name)
    description = record.description[len(record.name):].strip()
    terms = description.split(", ")
    assert len(terms) == 3
    assert terms[0] in ("Homo sapiens small nucleolar RNA", "Homo sapiens RNA")
    assert terms[2] == "small nucleolar RNA"
    name, symbol = terms[1].rsplit(None, 1)
    assert symbol.startswith("(")
    assert symbol.endswith(")")
    if name in ("U105B small nucleolar", "U105C small nucleolar"):
        return "H/ACA box"  # According to snoDB
    word1, word2 = name.rsplit(None, 1)
    assert word1 in ("C/D box", "H/ACA box")
    return word1


def _classify_ensembl_snorna(record):
    """Classify an ENST record; return None if it has no description text.

    Raises:
        Exception: if the description text matches no known pattern.
    """
    try:
        _name, description = record.description.split(None, 1)
    except ValueError:
        # Nothing follows the id; the caller reports the unknown class.
        return None
    if description.startswith("ENSG"):
        # Drop a leading gene identifier.
        _gene, description = description.split(None, 1)
    snoRNA = _ENSEMBL_SNORNA_EXACT.get(description)
    if snoRNA is None:
        if description.startswith(_ENSEMBL_HACA_PREFIXES):
            snoRNA = "H/ACA box"
        elif description.startswith(_ENSEMBL_CD_PREFIXES):
            snoRNA = "C/D box"
    if snoRNA is None:
        raise Exception("Unknown snoRNA with description '%s'" % record.description)
    return snoRNA


def read_snorna_annotations():
    """Read snoRNA.fa and classify each record by snoRNA class and biogenesis.

    Record ids must start with "NR_" or "ENST"; each format has its own
    classifier.  U3, U8 and U13 are labelled "Pol-II"; generic C/D box and
    H/ACA box snoRNAs are labelled "intronic".

    Returns:
        dict mapping record id -> ("<class> snoRNA", pathway).

    Raises:
        Exception: if a record cannot be classified.
        AssertionError: if a record does not have the expected layout.
    """
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
    filename = "snoRNA.fa"
    path = os.path.join(directory, filename)
    records = SeqIO.parse(path, "fasta")
    snorna_annotations = {}
    for record in records:
        if record.id.startswith("NR_"):
            snoRNA = _classify_refseq_snorna(record)
        else:
            assert record.id.startswith("ENST")
            snoRNA = _classify_ensembl_snorna(record)
        if snoRNA in ("U3", "U8", "U13"):
            biogenesis = "Pol-II"
        elif snoRNA in ("C/D box", "H/ACA box"):
            biogenesis = "intronic"
        else:
            raise Exception("Unknown snoRNA %s" % snoRNA)
        snoRNA = "%s snoRNA" % snoRNA
        snorna_annotations[record.id] = (snoRNA, biogenesis)
    return snorna_annotations

def read_scarna_annotations():
    """Read scaRNA.fa and assign a biogenesis pathway to each record.

    Records describing small Cajal body-specific RNA 2 or 17 are labelled
    "Pol-II"; every other record is labelled "intronic".

    Returns:
        dict mapping record id ->
        ("small Cajal body-specific RNA", pathway).

    Raises:
        AssertionError: if a record id does not start with "NR_".
    """
    import re

    directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
    filename = "scaRNA.fa"
    path = os.path.join(directory, filename)
    # The number must not be followed by another digit: a plain substring
    # test for "... RNA 2" would also match "... RNA 20", "... RNA 21", etc.
    # and label those records as Pol-II.
    pol2_pattern = re.compile(r"small Cajal body-specific RNA (?:2|17)(?!\d)")
    gene = "small Cajal body-specific RNA"
    scarna_annotations = {}
    for record in SeqIO.parse(path, "fasta"):
        assert record.id.startswith("NR_")
        if pol2_pattern.search(record.description):
            biogenesis = "Pol-II"
        else:
            biogenesis = "intronic"
        scarna_annotations[record.id] = (gene, biogenesis)
    return scarna_annotations

def read_rna_sizes(dataset):
    """Read "rnasize.<dataset>.txt" and sum size histograms per category.

    The file starts with a '#' header line naming the columns: rank,
    annotation, transcript, then the sizes 0..999 and a final ">999" bin.
    Each following line holds one transcript.  Short-RNA annotations are
    merged into "Pol-II short RNA", "Pol-III short RNA",
    "intronic short RNA" or "short RNA precursor"; all other annotations
    are kept as they are.

    Parameters:
        dataset: name substituted into the input file name.

    Returns:
        defaultdict mapping category -> array of length 1001 holding the
        summed counts for sizes 0..999.  The ">999" column is not
        accumulated, so the last element stays zero.

    Raises:
        Exception: if a transcript has a biogenesis pathway that is not
            expected for its annotation.
        KeyError: if a transcript is missing from its annotation table.
        AssertionError: if the file does not have the expected layout.
    """
    # annotation -> (transcript lookup table, pathways allowed for it)
    biogenesis_tables = {
        "scRNA": (read_scrna_annotations(), ("intronic", "Pol-III")),
        "snRNA": (read_snrna_annotations(), ("Pol-II", "Pol-III")),
        "snoRNA": (read_snorna_annotations(), ("Pol-II", "intronic")),
        "scaRNA": (read_scarna_annotations(), ("Pol-II", "intronic")),
    }
    filename = "rnasize.%s.txt" % dataset
    print("Reading", filename)
    lengths = arange(1000)
    counts = defaultdict(lambda: zeros(1001))
    # The with-block closes the file even when a check below fails.
    with open(filename) as handle:
        line = next(handle)
        assert line.startswith("#")
        words = line[1:].split()
        assert words[0] == 'rank'
        assert words[1] == 'annotation'
        assert words[2] == 'transcript'
        rnasizes = words[3:]
        assert len(rnasizes) == 1001
        for i in range(1000):
            assert rnasizes[i] == str(i)
        assert rnasizes[1000] == ">999"
        for number, line in enumerate(handle):
            words = line.split()
            assert int(words[0]) == number
            annotation = words[1]
            transcript = words[2]
            values = array(words[3:], float)
            assert len(values) == len(lengths) + 1
            if annotation in ("RPPH", "RMRP", "yRNA", "vRNA", "snar", "tRNA"):
                annotation = "Pol-III short RNA"
            elif annotation in biogenesis_tables:
                table, allowed = biogenesis_tables[annotation]
                gene, biogenesis = table[transcript]
                if biogenesis not in allowed:
                    raise Exception("Unexpected biogenesis pathway %s" % biogenesis)
                annotation = "%s short RNA" % biogenesis
            elif annotation in ("pretRNA", "presnRNA", "presnoRNA", "prescaRNA"):
                annotation = "short RNA precursor"
            # Drop the trailing ">999" column; only sizes 0..999 are summed.
            counts[annotation][lengths] += values[:-1]
    return counts

# Category name -> matplotlib color name.  The key order also fixes the order
# of the panels below, and the key set must equal the set of categories that
# read_rna_sizes() returns (asserted below).
# NOTE(review): the color values are not referenced by the plotting code in
# this section, which draws every bar in black.
colors = {"chrM": "gray",
          "rRNA": "darkgray",
          "Pol-II short RNA": "forestgreen",
          "Pol-III short RNA": "limegreen",
          "intronic short RNA": "mediumspringgreen",
          "short RNA precursor": "deepskyblue",
          "histone": "navy",
          "sense_proximal": "indianred",
          "sense_upstream": "brown",
          "sense_distal": "firebrick",
          "sense_distal_upstream": "maroon",
          "prompt": "chocolate",
          "antisense": "saddlebrown",
          "antisense_distal": "sandybrown",
          "antisense_distal_upstream": "peachpuff",
          "FANTOM5_enhancer": "gold",
          "roadmap_enhancer": "goldenrod",
          "roadmap_dyadic": "khaki",
          "novel_enhancer_CAGE": "darkgoldenrod",
          "novel_enhancer_HiSeq": "orange",
          "other_intergenic": "whitesmoke",
         }

# NOTE(review): not referenced anywhere else in the visible code.
annotations = tuple(colors.keys())

# Every size used for plotting is shifted down by this constant.
# NOTE(review): presumably the lengths in the rnasize table are 128 shorter
# than the sizes named here (180/200/400/420) - TODO confirm.
offset = 128
xmin = 180 - offset
xmax = 420 - offset

dataset = "MiSeq"
counts = read_rna_sizes(dataset)
# Every category in the file needs a panel, and every panel needs data.
assert counts.keys() == colors.keys()
fig = figure(figsize=(6,12))
# Full-figure axes with spines, ticks and tick labels hidden; it only carries
# the shared axis labels and the figure title.
ax = fig.add_subplot(111)
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['left'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
ax.set_xlabel("RNA size [bp]")
ax.set_ylabel("Read count")
ax.set_title(dataset, pad=16)
# One panel per category on a 7x3 grid (21 categories).
for i, category in enumerate(colors.keys()):
    ax = fig.add_subplot(7,3,i+1)
    # Draw one vertical black line per size with a non-zero count.
    for x in range(xmin, xmax+1):
        count = counts[category][x]
        if count > 0:
            plot([x, x], [0, count], color='black')
    # Keep the automatically chosen upper limit but anchor the axis at zero.
    ymin, ymax = ylim()
    ymin = 0
    # Dashed red lines at the two selected size limits.
    # NOTE(review): a label is set here but no legend() call is visible.
    plot([200-offset, 200-offset], [ymin, ymax], 'r--')
    plot([400-offset, 400-offset], [ymin, ymax], 'r--', label='Selected size limits')
    # Only the bottom row of the grid (panels 18-20) gets x tick labels.
    if i in (18, 19, 20):
        sizes = (200-offset, 100, 200, 400-offset)
        xticks(sizes, fontsize=8)
    else:
        xticks([])
    ax.yaxis.set_major_formatter(formatter)
    yticks(fontsize=8)
    title(category, fontsize=8, pad=2)
    # Plotting the limit lines may have changed the limits; restore them.
    xlim(xmin, xmax)
    ylim(ymin, ymax)
subplots_adjust(bottom=0.07, top=0.95, left=0.09, right=0.98, wspace=0.32)
filename = "figure_rna_sizes_%s.png" % dataset
print("Saving figure as %s" % filename)
savefig(filename)
